RFC: Add TensorForest classifier and regressor to canned estimators #3
Conversation
Adding the overview information. This review will remain open for comment until the end of Monday, July 16th (allowing for public holidays).

TensorForest Estimator

Objective: In this doc, we discuss the TensorForest Estimator API, which enables a user to create …
martinwicke left a comment
My main comment would be that we should start with the minimum set of parameters that gives users the flexibility they need. It seems some of the parameters are not that useful; could we remove them to make the API simpler?
Some questions:
- Could we have benchmarks for this?
- Could you discuss whether there are efficiencies to be had for whole-batch training? We spent a lot of time on such questions for the boosted tree Estimator, and I don't think we need to go into that much detail, but I would like to know whether there are obvious improvements we can make. Sometimes this type of thing can influence the API (e.g., by requiring a separate pretraining input or something).
rfcs/20180626-tensor-forest.md
Outdated
* **label_vocabulary:** A list of strings representing possible label values. If given, labels must be of string type and take values from `label_vocabulary`. If not given, labels are assumed to be already encoded: as integers or floats within [0, 1] for `n_classes=2`, and as integer values in {0, 1, ..., n_classes-1} for `n_classes>2`. An error is raised if the vocabulary is not provided and labels are strings.
* **n_trees:** The number of trees to create. Defaults to 100. There usually isn't any accuracy gain from using higher values.
* **max_nodes:** Defaults to 10,000. No tree is allowed to grow beyond `max_nodes` nodes, and training stops when all trees in the forest are this large.
* **num_splits_to_consider:** Defaults to `sqrt(num_features)` capped to be between 10 and 1000. In the extremely randomized tree training algorithm, only this many potential splits are evaluated for each tree node.
Are the 10 and 1000 boundaries universally accepted?
Nit: I would say "clipped"; to my ear, "capped" only works for the upper bound.
For now I just borrowed this from the original contrib implementation. It is not universal, though; not sure why the original author implemented it this way.
Not really; sklearn's ExtraTree uses `sqrt(num_features)` as the default.
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L1192-L1202
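For reference, a minimal scikit-learn equivalent of that default (note that scikit-learn applies no lower or upper clipping to the split-candidate count; in older releases this default was spelled `"auto"`):

```python
from sklearn.ensemble import ExtraTreesClassifier

# scikit-learn evaluates sqrt(n_features) candidate features per split
# for classification, with no lower/upper bounds like 10 and 1000.
clf = ExtraTreesClassifier(n_estimators=100, max_features="sqrt")
```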
So maybe I should remove this clip heuristic?
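For concreteness, here is a minimal sketch of how the parameters quoted above might be passed to the proposed estimator. The class name comes from this RFC, but the exact constructor signature shown here is an assumption, not a shipped TensorFlow symbol:

```python
import tensorflow as tf

# Hypothetical usage of the proposed canned estimator; parameter names
# and defaults are taken from the RFC text above.
feature_columns = [tf.feature_column.numeric_column("x", shape=[54])]

classifier = TensorForestClassifier(  # proposed; not yet in tf.estimator
    feature_columns=feature_columns,
    n_classes=7,
    label_vocabulary=None,        # labels already integer-encoded
    n_trees=100,                  # default; more trees rarely help
    max_nodes=10000,              # per-tree growth limit
    num_splits_to_consider=None,  # default: sqrt(num_features)
    split_after_samples=250,      # samples a leaf sees before splitting
)
```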
rfcs/20180626-tensor-forest.md
Outdated
* **max_nodes:** Defaults to 10,000. No tree is allowed to grow beyond `max_nodes` nodes, and training stops when all trees in the forest are this large.
* **num_splits_to_consider:** Defaults to `sqrt(num_features)` capped to be between 10 and 1000. In the extremely randomized tree training algorithm, only this many potential splits are evaluated for each tree node.
* **split_after_samples:** Defaults to 250. In our online version of extremely randomized tree training, we pick a split for a node after it has accumulated this many training samples.
* **bagging_fraction:** If less than 1.0, then each tree sees only a different, randomly sampled (without replacement), `bagging_fraction`-sized subset of the training data. Defaults to 1.0 (no bagging) because it has not given any accuracy improvement in our experiments so far.
If this gives no improvement, can we remove this argument?
We can always add stuff back, but we can never take it away (except at major versions) so we should be conservative in what we add.
For now I just borrowed this from the original contrib implementation, but thanks for your suggestion; I guess I can use the benchmark tool to find out whether the original claim is valid.
Since the original paper did have some numbers suggesting bootstrapping is not helping, I'll remove it from the API for now.
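For reference, these are the semantics the dropped `bagging_fraction` parameter would have had; the helper below is purely illustrative, not part of any API:

```python
import numpy as np

def bagging_indices(num_examples, bagging_fraction, rng):
    """Indices of the per-tree training subset, drawn without replacement."""
    subset_size = int(num_examples * bagging_fraction)
    return rng.choice(num_examples, size=subset_size, replace=False)

# Each tree would see its own random 80% slice of 1000 examples.
rng = np.random.default_rng(seed=0)
per_tree_indices = [bagging_indices(1000, 0.8, rng) for _ in range(10)]
```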
For the benchmark, yeah, we might use this: https://www.openml.org/search?q=ExtraTrees&type=flow

In this case, not really, since the trees we are using in TensorForest are Hoeffding Trees, which are incremental trees, so we don't require full-batch training.
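A hedged sketch of what this enables: training can stream mini-batches through a standard Estimator `input_fn`, with no full-batch pass. `classifier` refers to the hypothetical estimator sketched earlier:

```python
import numpy as np
import tensorflow as tf

def input_fn():
    # Synthetic stand-in data; real code would read from files.
    features = {"x": np.random.rand(10000, 54).astype(np.float32)}
    labels = np.random.randint(0, 7, size=10000)
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.shuffle(1024).batch(256)  # small batches, streamed

# The Hoeffding trees update per batch, so this never materializes the
# whole training set at once:
# classifier.train(input_fn=input_fn, max_steps=1000)
```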
@nataliaponomareva @martinwicke I think we're good to merge this now. Waiting for your LGTM and I'll merge.
Can we reflect the discussion notes somewhere here? Could be as a link to a doc, even in the comment thread. I just don't want them lost. @tanzhenyu
Agreed; previously we've linked them at the bottom of the RFC. Either that or including them in this PR thread would also work, and we'll link the PR discussion at the bottom of the RFC.
Talked with Edd offline; he will post it within the ready-to-push RFC.
Notes from the review committee meeting on 2018-08-07:

- Simplified code with only a limited subset of features (obviously, excluding all the experimental ones)
- New estimator interface, support for new feature columns and losses
Can you also copy this over from the doc:

"We will try to reuse as much code from canned boosted trees as possible (proto, inference, etc.)."
rfcs/20180626-tensor-forest.md
Outdated
### Interface

### TensorForestClassifier

    ```
use ```python to get the syntax highlighting?
rfcs/20180626-tensor-forest.md
Outdated
### TensorForestRegressor

    ```
use ```python to get the syntax highlighting?
rfcs/20180626-tensor-forest.md
Outdated
4. Otherwise, `(x_i, y_i)` is used to update the statistics of every split in the growing statistics of leaf `l_i`. If leaf `l_i` has now seen `split_after_samples` data points since creating all of its potential splits, the split with the best score is chosen, and the tree structure is grown.
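A hedged, illustrative sketch of this update rule; every name below is made up for exposition and is not the RFC's internal API:

```python
def update_leaf(leaf, x_i, y_i, split_after_samples=250):
    # Accumulate statistics for each of the leaf's candidate splits.
    for split in leaf.candidate_splits:
        split.update_statistics(x_i, y_i)
    leaf.num_samples_seen += 1
    # Once enough samples have been seen, commit the best-scoring split,
    # turning this leaf into an internal node and growing the tree.
    if leaf.num_samples_seen >= split_after_samples:
        best = max(leaf.candidate_splits, key=lambda s: s.score())
        leaf.convert_to_internal_node(best)
```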
## BenchMark
Benchmark
rfcs/20180626-tensor-forest.md
Outdated
|Covertype| 581k| 54| 7| 83.0| 85.0|
|HiGGS| 11M| 28| 2| 70.9| 71.7|

With single-machine training, TensorForest finishes much faster on a big dataset like HIGGS, taking about one percent of the time scikit-learn requires.
Hm, where is time in this table? It is just performance metrics, right?
Yeah, it's just performance metrics; I took it from the workshop paper.
But what I am saying is that the statement "much faster on big datasets" is not substantiated by this table. Either keep the table, say that it is from resource A demonstrating that the quality is on par with scikit-learn, and remove the statement that it trains faster; or add a reference to the resource which states that it trains faster.
I see. I'll add a citation to the workshop paper, as the claim was also from the paper.
Hi, may I know whether the TensorForest package is still supported under TensorFlow 2.0.x? Many thanks.
Since tree algorithms are among the most popular algorithms used in Kaggle competitions, and we already have a contrib project, tensor_forest, that people like, it would be beneficial to move them into canned estimators.
cc: @nataliaponomareva